{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "8b7772f7-75a1-44ba-b142-2fc02d53c95c",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pytrt as tp\n",
    "import os\n",
    "import cv2\n",
    "\n",
    "# Set pytrt's log level to Warning before any engine work is done.\n",
    "# (Presumably hides lower-severity messages; a [warn] line still appears in a later stored output.)\n",
    "tp.set_log_level(tp.LogLevel.Warning)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "d50e3d1c-cb64-47c2-b56d-8a355c9f489e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Path of the serialized engine file; it is only compiled when the file is not on disk yet.\n",
    "engine_file = \"yolov5m_2.fp32.trtmodel\"\n",
    "engine_cached = os.path.exists(engine_file)\n",
    "\n",
    "if not engine_cached:\n",
    "    # Resolve the yolov5m ONNX model through pytrt's hub helper, then compile it to engine_file.\n",
    "    # NOTE(review): the leading 5 is presumably the maximum batch size -- confirm against pytrt docs.\n",
    "    onnx_file = tp.onnx_hub(\"yolov5m\")\n",
    "    tp.compile_onnx_to_file(5, onnx_file, engine_file)\n",
    "\n",
    "# NOTE(review): the trailing 0 and 0.25 look like a device id and a confidence threshold -- confirm.\n",
    "yolo = tp.Yolo(engine_file, tp.YoloType.V5, 0, 0.25)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "d232b753-0230-4738-b4de-658da99f78dd",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<pytrt.libpytrtc.Yolo at 0x7f59b7942ab0>"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Bare expression: the notebook displays the repr of the detector object created above.\n",
    "yolo"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "7a0c7c07-f516-4499-b655-262070f596d6",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[<Box: left=438.02, top=434.06, right=521.42, bottom=719.22, class_label=27, confidence=0.81457>,\n",
       " <Box: left=136.21, top=200.16, right=1110.42, bottom=713.44, class_label=0, confidence=0.72617>,\n",
       " <Box: left=750.17, top=40.92, right=1141.72, bottom=709.11, class_label=0, confidence=0.90385>]"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Load the test image, run it through the detector and display the returned boxes.\n",
    "image_path = \"inference/zand.jpg\"\n",
    "image = cv2.imread(image_path)\n",
    "\n",
    "# cv2.imread does not raise on a missing or unreadable file; it returns None.\n",
    "# Fail here with a clear message instead of passing None into the detector.\n",
    "if image is None:\n",
    "    raise FileNotFoundError(f\"cv2.imread could not read {image_path!r}\")\n",
    "\n",
    "# commit() hands the image to the detector and returns a handle; get() retrieves the box list.\n",
    "bboxes = yolo.commit(image).get()\n",
    "bboxes"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "f54d7eee-9aaa-4eec-a564-ddc10859b3bc",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<pytrt.libpytrtc.Yolo at 0x7f59b7942ab0>"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Display the detector again after inference; the stored output shows the same object address as before.\n",
    "yolo"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "7954f835-6f61-4a92-b6df-da4af3bad37c",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "True"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Read the detector's `valid` attribute (stored output: True).\n",
    "# NOTE(review): presumably reports whether the engine loaded successfully -- confirm against pytrt docs.\n",
    "yolo.valid"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "a0501910-5620-4bdd-96dd-d8ea97d4162d",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torchvision.models as models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "4eb65661-e624-42b1-8002-78a8b8004b22",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Constant dummy batch of shape (5, 3, 224, 224) filled with 0.5, moved to the GPU.\n",
    "# NOTE: `input` shadows the Python builtin; the name is kept because later cells reference it.\n",
    "input = torch.full((5, 3, 224, 224), 0.5)\n",
    "input = input.cuda()\n",
    "\n",
    "# ResNet-18 with pretrained weights, switched to eval mode and moved to the GPU.\n",
    "m = models.resnet18(pretrained=True)\n",
    "m = m.eval().cuda()\n",
    "\n",
    "# Convert the PyTorch module with pytrt, passing the dummy batch along with it.\n",
    "# NOTE(review): `input` is presumably used as the example input for export -- confirm.\n",
    "trt_model = tp.from_torch(m, input)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "75ecd5bf-7851-4ed2-aa67-78089c07a14c",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[2021-09-03 19:41:12][\u001b[33mwarn\u001b[0m][trt_infer.cpp:26]:NVInfer: The logger passed into createInferBuilder differs from one already provided for an existing builder, runtime, or refitter. TensorRT maintains only a single logger pointer at any given time, so the existing value, which can be retrieved with getLogger(), will be used instead. In order to use a new logger, first destroy all existing builder, runner or refitter objects.\n",
      "\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "tensor(8.3447e-06, device='cuda:0', grad_fn=<MaxBackward1>)"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Feed the same batch to the PyTorch model and to the converted model,\n",
    "# then display the largest absolute element-wise difference between the two outputs.\n",
    "tout = m(input)\n",
    "rtout = trt_model(input)\n",
    "torch.max(torch.abs(tout - rtout))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "815c9692-3cce-461d-86e2-26c6a4e49fc0",
   "metadata": {},
   "outputs": [],
   "source": [
    "import time"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "2b5b7440-bc64-46ac-81fc-4a677b8fde07",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Time `ntest` forward passes of the PyTorch model and print the total in milliseconds.\n",
    "#\n",
    "# CUDA kernels are launched asynchronously, so the host clock must only be read after\n",
    "# torch.cuda.synchronize(); otherwise the timer can stop while GPU work is still queued.\n",
    "# Re-run this cell to regenerate the timing output.\n",
    "ntest = 100\n",
    "with torch.no_grad():\n",
    "    torch.cuda.synchronize()  # drain earlier GPU work so it is not billed to this loop\n",
    "    t0 = time.perf_counter()\n",
    "    for i in range(ntest):\n",
    "        tout = m(input)\n",
    "    torch.cuda.synchronize()  # wait until every queued kernel has finished\n",
    "    fee = time.perf_counter() - t0\n",
    "\n",
    "# Total wall time for all ntest iterations, not per iteration.\n",
    "print(f\"{fee*1000:.2f} ms\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "24f401ef-5a95-4f9f-b101-b6a6292e3f08",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Time `ntest` forward passes of the converted model and print the total in milliseconds.\n",
    "#\n",
    "# The device is synchronized before starting and before stopping the clock so that only\n",
    "# completed GPU work is measured, whether or not the wrapper synchronizes internally.\n",
    "# Re-run this cell to regenerate the timing output.\n",
    "ntest = 100\n",
    "with torch.no_grad():\n",
    "    torch.cuda.synchronize()  # drain earlier GPU work so it is not billed to this loop\n",
    "    t0 = time.perf_counter()\n",
    "    for i in range(ntest):\n",
    "        tout = trt_model(input)\n",
    "    torch.cuda.synchronize()  # wait until every queued kernel has finished\n",
    "    fee = time.perf_counter() - t0\n",
    "\n",
    "# Total wall time for all ntest iterations, not per iteration.\n",
    "print(f\"{fee*1000:.2f} ms\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "717af853-c527-4652-85a7-5ae1b7f3ed1f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# The previous lookup, `tp.convert_torch_to_trt?`, reported that no such object exists.\n",
    "# `tp.from_torch` is the converter this notebook actually calls, so show its help instead.\n",
    "tp.from_torch?"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "PyTorch1.8",
   "language": "python",
   "name": "torch1.8"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
